{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lTFtyC9xmB2x"
   },
   "source": [
    "## **Mount Drive**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 24400,
     "status": "ok",
     "timestamp": 1742838638757,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "V0CVQ-nmmDzz",
    "outputId": "461607a9-7609-4ea6-c207-47ee6d5ce2d5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/drive\n"
     ]
    }
   ],
   "source": [
    "# Mount Google Drive at /content/drive; the call below is the only place the mount point is set.\n",
    "from google.colab import drive\n",
    "drive.mount(\"/content/drive\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WRl3uDKqzWdJ"
   },
   "source": [
    "## **Installations**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 26483,
     "status": "ok",
     "timestamp": 1742838665236,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "7kbv7L8HlG-i",
    "outputId": "893af20e-c1e7-42c1-849e-7b8f170deba3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting kaleido\n",
      "  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)\n",
      "Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.9/79.9 MB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: kaleido\n",
      "Successfully installed kaleido-0.2.1\n"
     ]
    }
   ],
   "source": [
    "# Pin kaleido to the version recorded in this notebook's last run so image export is reproducible;\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
    "%pip install -q kaleido==0.2.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I_t0KC78l3rh"
   },
   "source": [
    "## **Imports**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 1018,
     "status": "ok",
     "timestamp": 1742838666255,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "Tx65iM6lRfxy"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1742838666283,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "lvjDUPzrUelB"
   },
   "outputs": [],
   "source": [
    "from itertools import product, chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 27,
     "status": "ok",
     "timestamp": 1742838666313,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "PogZS5zMbPod"
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 4247,
     "status": "ok",
     "timestamp": 1742838670588,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "FLYdH4ZGRYLE"
   },
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "\n",
    "from plotly.subplots import make_subplots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jfwAq60dRp8s"
   },
   "source": [
    "## **Plotting functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1742838670595,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "OtcFz9WJRqQz"
   },
   "outputs": [],
   "source": [
    "def training_validation_loss_plot(train_history, validation_history, metric=\"loss\", n_epochs=20):\n",
    "    \"\"\"Plot the training and validation curves of one metric on a single figure.\n",
    "\n",
    "    Args:\n",
    "        train_history: Training history; must provide ``.index`` and ``[metric]``\n",
    "            lookup (presumably a pandas DataFrame with one row per epoch).\n",
    "        validation_history: Validation history with the same layout.\n",
    "        metric: Name of the column to plot; also used for the axis and trace labels.\n",
    "        n_epochs: Upper bound of the x-axis, i.e. the last epoch shown.\n",
    "\n",
    "    Returns:\n",
    "        A 500x500 ``go.Figure`` holding one training and one validation trace.\n",
    "    \"\"\"\n",
    "    # str.capitalize() lower-cases everything after the first letter, which would\n",
    "    # turn an all-caps name such as \"AUC\" into \"Auc\"; keep such names as written.\n",
    "    label = metric if metric.isupper() else metric.capitalize()\n",
    "\n",
    "    fig = go.Figure()\n",
    "\n",
    "    fig.add_traces([\n",
    "        go.Scatter(\n",
    "            # index + 1: presumably converts a 0-based row index into 1-based epoch numbers.\n",
    "            x=train_history.index + 1,\n",
    "            y=train_history[metric],\n",
    "            name=f\"Training {label}\"),\n",
    "        go.Scatter(\n",
    "            x=validation_history.index + 1,\n",
    "            y=validation_history[metric],\n",
    "            name=f\"Validation {label}\"),\n",
    "    ])\n",
    "\n",
    "    fig.update_layout(\n",
    "        height=500, width=500,\n",
    "        xaxis_title=\"Epoch\", yaxis_title=label,\n",
    "        # The x values are shifted by one, so the window has to be [1, n_epochs];\n",
    "        # a [0, n_epochs - 1] window would clip the final epoch.\n",
    "        xaxis_range=[1, n_epochs], yaxis_range=[0, 1],\n",
    "        legend=dict(xanchor=\"right\", yanchor=\"top\")\n",
    "    )\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1742838670599,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "ioP2hNf2R1aQ"
   },
   "outputs": [],
   "source": [
    "def training_validation_metrics_plot(train_history, validation_history):\n",
    "    \"\"\"Draw training and validation curves for four metrics on a 2x2 subplot grid.\n",
    "\n",
    "    The panels are filled in row-major order with the columns \"loss\", \"AUC\",\n",
    "    \"f1-score\" and \"fβ-score\". Only the first panel contributes legend entries.\n",
    "\n",
    "    Args:\n",
    "        train_history: Training history; must provide ``.index`` and lookup of the\n",
    "            four metric columns (presumably a pandas DataFrame, one row per epoch).\n",
    "        validation_history: Validation history with the same layout.\n",
    "\n",
    "    Returns:\n",
    "        A 750x750 plotly figure with two traces per panel.\n",
    "    \"\"\"\n",
    "    metrics = [\"loss\", \"AUC\", \"f1-score\", \"fβ-score\"]\n",
    "    n_cols = 2\n",
    "\n",
    "    fig = make_subplots(rows=2, cols=n_cols)\n",
    "\n",
    "    for i, metric in enumerate(metrics):\n",
    "        # Row-major placement: panel i lands on the 1-based (row, col) cell.\n",
    "        row, col = divmod(i, n_cols)\n",
    "        row, col = row + 1, col + 1\n",
    "\n",
    "        # str.capitalize() would render \"AUC\" as \"Auc\"; keep all-caps names as written.\n",
    "        title = metric if metric.isupper() else metric.capitalize()\n",
    "\n",
    "        fig.add_traces([\n",
    "            go.Scatter(\n",
    "                x=train_history.index + 1,\n",
    "                y=train_history[metric],\n",
    "                name=\"Training\",\n",
    "                line=dict(color=\"#636EFA\"),\n",
    "                showlegend=(i == 0),\n",
    "            ),\n",
    "            go.Scatter(\n",
    "                x=validation_history.index + 1,\n",
    "                y=validation_history[metric],\n",
    "                name=\"Validation\",\n",
    "                line=dict(color=\"#EF553B\"),\n",
    "                showlegend=(i == 0),\n",
    "            ),\n",
    "        # add_traces takes one row and one column entry per trace, hence the pairs.\n",
    "        ], rows=(row, row), cols=(col, col))\n",
    "\n",
    "        fig.update_xaxes(\n",
    "            title=\"Epoch\", range=[1, 20],\n",
    "            row=row, col=col\n",
    "        )\n",
    "\n",
    "        fig.update_yaxes(\n",
    "            title=title, range=[0, 1],\n",
    "            row=row, col=col\n",
    "        )\n",
    "\n",
    "    fig.update_layout(height=750, width=750)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1742838670614,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "Whbrute1VMtO"
   },
   "outputs": [],
   "source": [
    "def testing_metrics_plot(test_results):\n",
    "    \"\"\"Build a bar chart of the test metrics found in row 0 of ``test_results``.\n",
    "\n",
    "    One bar shows \"AUC_mean\" with an error bar of 1.96 * \"AUC_std\"; three more bars\n",
    "    show \"loss\", \"f1-score\" and \"fβ-score\". A dashed green line marks y = 0.9.\n",
    "\n",
    "    Args:\n",
    "        test_results: Table with the columns named above and a row labelled 0\n",
    "            (presumably a single-row pandas DataFrame read from CSV).\n",
    "\n",
    "    Returns:\n",
    "        A 500x500 ``go.Figure`` without a legend and with the y-axis fixed to [0, 1].\n",
    "    \"\"\"\n",
    "    metrics = [\"loss\", \"f1-score\", \"fβ-score\"]\n",
    "\n",
    "    # Read scalars from row 0, the same row the other metrics are taken from.\n",
    "    # Passing the whole column as y and a list wrapping a Series as the error\n",
    "    # array would hand plotly mismatched / nested data for a single \"AUC\" bar.\n",
    "    auc_mean = test_results.loc[0, \"AUC_mean\"]\n",
    "    auc_std = test_results.loc[0, \"AUC_std\"]\n",
    "\n",
    "    fig = go.Figure()\n",
    "\n",
    "    fig.add_trace(\n",
    "        go.Bar(\n",
    "            x=[\"AUC\"], y=[auc_mean],\n",
    "            width=0.6, error_y=dict(array=[1.96 * auc_std]),\n",
    "        )\n",
    "    )\n",
    "\n",
    "    fig.add_trace(\n",
    "        go.Bar(\n",
    "            x=metrics, y=test_results[metrics].loc[0],\n",
    "            width=0.6,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Reference line at 0.9 across the whole plot.\n",
    "    fig.add_hline(\n",
    "        y=0.9,\n",
    "        line_width=0.75,\n",
    "        line_dash=\"dash\",\n",
    "        line_color=\"green\"\n",
    "    )\n",
    "\n",
    "    fig.update_layout(\n",
    "        height=500, width=500,\n",
    "        yaxis_range=[0, 1],\n",
    "        showlegend=False\n",
    "    )\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EcYCCAHoSpCu"
   },
   "source": [
    "## **Plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1742839192828,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "pn9boY45SuHS"
   },
   "outputs": [],
   "source": [
    "def load_data(architecture, lead, format, purpose=\"train\"):\n",
    "    assert purpose in [\"train\", \"validation\", \"test\"], \"Invalid purpose selected.\"\n",
    "\n",
    "    directory = \"testing\" if purpose == \"test\" else \"training\"\n",
    "    history = pd.read_csv(f\"/content/drive/MyDrive/Thesis/{directory}/{format}/{architecture}_{format}_{purpose}.csv\")\n",
    "\n",
    "    return history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ARCHITECTURES = [\"lstm\", \"gru\", \"cnn_gru\"]\n",
    "LEADS = [\"8leads\", \"12leads\"]\n",
    "FORMATS = [\"raw\", \"filtered\"]\n",
    "\n",
    "# \"raw\" is combined with \"12leads\" only; \"filtered\" is combined with both lead sets.\n",
    "# The chain is materialised into a list so that tqdm can read its length and\n",
    "# display a total and percentage instead of a bare iteration counter.\n",
    "experiments = list(chain(\n",
    "    product(ARCHITECTURES, LEADS[1:], FORMATS[:1]),\n",
    "    product(ARCHITECTURES, LEADS, FORMATS[1:])\n",
    "))\n",
    "\n",
    "for (architecture, lead, format) in tqdm(experiments):\n",
    "\n",
    "    # NOTE(review): load_data does not put `lead` into the CSV path, so the 8-lead and\n",
    "    # 12-lead runs of the same architecture/format read identical files -- confirm.\n",
    "    train_history = load_data(architecture, lead, format, purpose=\"train\")\n",
    "    validation_history = load_data(architecture, lead, format, purpose=\"validation\")\n",
    "    test_results = load_data(architecture, lead, format, purpose=\"test\")\n",
    "\n",
    "    # Output file stem, e.g. \"lstm_12leads_raw\".\n",
    "    experiment = \"_\".join([architecture, lead, format])\n",
    "\n",
    "    fig = training_validation_metrics_plot(train_history, validation_history)\n",
    "    fig.write_image(f\"/content/drive/MyDrive/Thesis/visualizations/training_validation_plots/{experiment}.png\")\n",
    "\n",
    "    fig = testing_metrics_plot(test_results)\n",
    "    fig.write_image(f\"/content/drive/MyDrive/Thesis/visualizations/testing_plots/{experiment}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15785,
     "status": "ok",
     "timestamp": 1742839216963,
     "user": {
      "displayName": "Андрій Ванджура",
      "userId": "06159425094443245497"
     },
     "user_tz": -120
    },
    "id": "D1akE8R7SlD8",
    "outputId": "6af78a45-da4d-4e1e-c978-2de5b6eec52b"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6/6 [00:15<00:00,  2.62s/it]\n"
     ]
    }
   ],
   "source": [
    "# Every listed architecture is paired with every listed format.\n",
    "ARCHITECTURES = [\"resnet18\", \"resnet34\", \"resnet50\"]\n",
    "FORMATS = [\"spectrograms\", \"scaleograms\"]\n",
    "\n",
    "experiments = list(product(ARCHITECTURES, FORMATS))\n",
    "\n",
    "for architecture, format in tqdm(experiments):\n",
    "\n",
    "    # These experiments pass no lead setting, hence None for that argument.\n",
    "    train_history, validation_history, test_results = (\n",
    "        load_data(architecture, None, format, purpose=purpose)\n",
    "        for purpose in (\"train\", \"validation\", \"test\")\n",
    "    )\n",
    "\n",
    "    # Output file stem, e.g. \"resnet18_spectrograms\".\n",
    "    experiment = f\"{architecture}_{format}\"\n",
    "\n",
    "    fig = training_validation_metrics_plot(train_history, validation_history)\n",
    "    fig.write_image(\n",
    "        f\"/content/drive/MyDrive/Thesis/visualizations/training_validation_plots/{experiment}.png\"\n",
    "    )\n",
    "\n",
    "    fig = testing_metrics_plot(test_results)\n",
    "    fig.write_image(\n",
    "        f\"/content/drive/MyDrive/Thesis/visualizations/testing_plots/{experiment}.png\"\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyNCCwX15TUpVysuYj822hgb",
   "collapsed_sections": [
    "I_t0KC78l3rh"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
